import cv2  # 用于图像读取和处理
import numpy as np  # 主要用于数值计算
import os  # 文件和目录操作
from os.path import exists  # 用于检测文件或目录是否存在
from imutils import paths  # 获取文件路径
import pickle  # 用于序列化和反序列化Python对象
from tqdm import tqdm  # 用于在循环中添加进度条
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input  # VGG16模型及预处理函数
from tensorflow.keras.preprocessing import image  # 用于加载和处理图像

import logging  # 日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # 配置日志格式

########## 步骤1：定义文件大小检查函数 ##########
def get_size(file):
    """Return the size of *file* in megabytes (MiB).

    Parameters:
        file: path to an existing file.

    Returns:
        float: file size in MiB (bytes / 1024 / 1024).
    """
    # os.path.getsize yields bytes; scale straight to MiB in one expression.
    return os.path.getsize(file) / (1024 * 1024)

########## 步骤2：定义特征与标签构建主函数 ##########
def createXY(train_folder, dest_folder, method='vgg', batch_size=64):
    """
    Build features (X) and labels (y) from an image folder and cache them as pickles.

    Labels are parsed from file names: a basename whose first dot-separated
    token is 'dog' maps to 1, anything else to 0 (cats-vs-dogs convention).

    Parameters:
        train_folder: directory whose images are enumerated via imutils.paths.
        dest_folder:  directory where X.pkl / y.pkl are written and reloaded from.
        method:       'vgg'  -> 224x224 RGB images through VGG16
                                (include_top=False, global max pooling);
                      'flat' -> 32x32 grayscale images flattened to 1024-dim vectors.
        batch_size:   number of images processed per batch (bounds memory use).

    Returns:
        (X, y): list of per-image feature vectors and list of int labels.

    Raises:
        ValueError: if `method` is neither 'vgg' nor 'flat'.
    """
    if method not in ('vgg', 'flat'):
        # Fail fast: an unrecognized method would otherwise surface as a
        # confusing NameError (unbound `model`/`img`) deep inside the loop.
        raise ValueError(f"method must be 'vgg' or 'flat', got {method!r}")

    x_file_path = os.path.join(dest_folder, "X.pkl")  # cached feature file
    y_file_path = os.path.join(dest_folder, "y.pkl")  # cached label file

    # If both cache files already exist, load and return them directly.
    # NOTE(review): pickle.load is unsafe on untrusted files -- only load
    # caches this function itself produced.
    if exists(x_file_path) and exists(y_file_path):
        logging.info("X和y已经存在，直接读取")
        logging.info(f"X文件大小:{get_size(x_file_path):.2f}MB")
        logging.info(f"y文件大小:{get_size(y_file_path):.2f}MB")
        with open(x_file_path, 'rb') as f:
            X = pickle.load(f)
        with open(y_file_path, 'rb') as f:
            y = pickle.load(f)
        return X, y

    logging.info("读取所有图像，生成X和y")
    image_paths = list(paths.list_images(train_folder))  # all image file paths
    X = []  # feature vectors
    y = []  # int labels

    model = None
    if method == 'vgg':
        # Headless VGG16 with global max pooling: one fixed-length vector per image.
        model = VGG16(weights='imagenet', include_top=False, pooling="max")
        logging.info("完成构建 VGG16 模型")

    # Process images in batches to keep peak memory bounded.
    num_batches = (len(image_paths) + batch_size - 1) // batch_size  # ceiling division
    for idx in tqdm(range(num_batches), desc="读取图像"):
        start = idx * batch_size
        end = min(start + batch_size, len(image_paths))

        batch_images = []  # raw image arrays for this batch
        batch_labels = []  # labels for this batch
        for image_path in image_paths[start:end]:
            if method == 'vgg':
                img = image.load_img(image_path, target_size=(224, 224))  # RGB, resized
                img = image.img_to_array(img)
            else:  # method == 'flat'
                img = cv2.imread(image_path, 0)  # grayscale
                img = cv2.resize(img, (32, 32))
            batch_images.append(img)

            # Label from the filename prefix, e.g. 'dog.42.jpg' -> 1, 'cat.7.jpg' -> 0.
            label = 1 if image_path.split(os.path.sep)[-1].split(".")[0] == 'dog' else 0
            batch_labels.append(label)

        batch_images = np.array(batch_images)
        if method == 'vgg':
            batch_images = preprocess_input(batch_images)  # VGG16-specific normalization
            batch_pixels = model.predict(batch_images, verbose=0)  # extracted features
        else:
            batch_pixels = batch_images.reshape((batch_images.shape[0], -1))  # flatten

        X.extend(batch_pixels)
        y.extend(batch_labels)

    logging.info(f"X.shape: {np.shape(X)}")
    logging.info(f"y.shape: {np.shape(y)}")
    # Log file sizes only AFTER each `with` block closes (flushes) the file;
    # querying the size while the handle is still open can report a
    # stale/partial value from buffered, unwritten data.
    with open(x_file_path, 'wb') as f:
        pickle.dump(X, f)
    logging.info(f"X文件大小: {get_size(x_file_path):.2f} MB")
    with open(y_file_path, 'wb') as f:
        pickle.dump(y, f)
    logging.info(f"y文件大小: {get_size(y_file_path):.2f} MB")

    return X, y